Linear predict #87

Merged
hmgaudecker merged 77 commits into main from linear-predict
Mar 18, 2026
Conversation

@hmgaudecker (Member)

Add linear Kalman predict fast path

Fixes #36.

Summary

  • Add linear_kalman_predict that uses direct matrix algebra (F @ x + c) instead of the unscented sigma-point transform, for models where all factors use linear or constant transition functions
  • maximization_inputs.py auto-selects the fast path via is_all_linear() — no API changes needed
  • Refactor likelihood functions to accept a generic predict_func callable instead of hardcoded kalman_predict + transition_func
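Conceptually, the fast path is the textbook linear Kalman prediction. A minimal sketch in plain covariance form (hypothetical names and signature; the actual skillmodels filters work in square-root form):

```python
import jax.numpy as jnp


def linear_predict_sketch(x, p, f_mat, c_vec, q):
    """Textbook linear Kalman predict (illustrative, not the skillmodels API).

    x: (n,) state mean, p: (n, n) state covariance,
    f_mat: (n, n) transition matrix, c_vec: (n,) intercept,
    q: (n, n) shock covariance.
    """
    x_new = f_mat @ x + c_vec          # mean propagates exactly
    p_new = f_mat @ p @ f_mat.T + q    # covariance propagates exactly
    return x_new, p_new
```

No sigma points are needed: for linear transitions the mean and covariance propagate in closed form.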

Benchmark results

Tested on health-cognition (no_feedback_to_investments_linear, 4 latent factors, GPU 8 GiB):

| Metric | om.Constraints (unscented) | linear-predict |
| --- | --- | --- |
| GPU: per iter (100 iters) | 8.87s | 8.36s (~6% speedup) |
| JIT warmup | 43.3s | 40.4s |
| GPU memory | higher (OOMs with ~5 GiB free) | lower (runs with ~5 GiB free) |
| CPU: per iter (10 iters) | 109.6s | 107.9s (~1.6%, within noise) |

The main benefit is reduced GPU memory usage — the unscented transform generates 2n+1 sigma points which are expensive to differentiate through, while the linear path uses a single matrix multiply. On a small model (4 factors), the speed gain is modest (~6% GPU), but the memory reduction is the difference between fitting on GPU vs OOMing when memory is constrained.
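For intuition on why the two paths agree: a standard unscented transform propagates 2n+1 sigma points, and for a linear transition their weighted mean reproduces the single matrix multiply exactly. A sketch using standard Julier sigma points and weights (illustrative only, not the skillmodels implementation):

```python
import jax.numpy as jnp


def unscented_predict_mean(x, chol_p, f_mat, c_vec, kappa=2.0):
    """Mean of 2n+1 sigma points pushed through a linear transition.

    Standard Julier sigma points/weights (illustrative). For a linear
    map the result equals f_mat @ x + c_vec exactly, which is why the
    direct matrix path can replace the unscented predict.
    """
    n = x.shape[0]
    scale = jnp.sqrt(n + kappa)
    offsets = scale * chol_p.T              # rows are scaled Cholesky columns
    points = jnp.concatenate([x[None, :], x + offsets, x - offsets])
    transformed = points @ f_mat.T + c_vec  # (2n + 1, n)
    w0 = kappa / (n + kappa)
    wi = 1.0 / (2.0 * (n + kappa))
    weights = jnp.concatenate([jnp.array([w0]), jnp.full(2 * n, wi)])
    return weights @ transformed
```

The symmetric ± offsets cancel in the weighted sum, so the linear path skips materializing (and differentiating through) the (2n+1, n) sigma-point array entirely.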

Test plan

  • Added unit tests for linear_kalman_predict and is_all_linear in test_kalman_filters.py
  • All 351 existing tests pass
  • Benchmarked against om.Constraints on real estimation task

hmgaudecker and others added 30 commits January 8, 2026 18:53
Introduce strongly-typed dataclasses for model configuration:
- Dimensions, Labels, Anchoring, EstimationOptions, TransitionInfo
- FactorEndogenousInfo, EndogenousFactorsInfo

This improves type safety and enables IDE autocompletion while keeping
user-facing model_dict as a plain dictionary.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Replace dict fields with frozendict in frozen dataclasses to ensure
true immutability:
- Labels.aug_periods_to_periods
- Labels.aug_stages_to_stages
- Anchoring.outcomes
- TransitionInfo.param_names, individual_functions, function_names
- EndogenousFactorsInfo.aug_periods_to_aug_period_meas_types, factor_info

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Update process_model() to return a ProcessedModel frozen dataclass
and update all consumers to use attribute access instead of dict access.

This provides:
- Better type safety with explicit typed fields
- Immutability via frozen dataclass
- IDE autocomplete support
- Clear documentation of the model structure

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…c so that config.TEST_DATA_DIR is valid also for the skillmodels package (as opposed to the project).
The filtered_states DataFrame and params index both use aug_period as the
period identifier, not period. This fixes KeyError when calling
decompose_measurement_variance.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
hmgaudecker and others added 15 commits March 15, 2026 14:47
Remove list from loc type union, convert callers to tuple().
Update anchoring test expectations from list to tuple.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The viz code assumed states DataFrames always have `aug_period` as a
column, but pre-computed states (e.g. from health-cognition) may carry
`period` in the index instead. Add `_normalize_states_columns` to
promote index levels and rename `period` → `aug_period` when needed.

Also document the period vs aug_period convention in CLAUDE.md.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@codecov

codecov bot commented Mar 18, 2026

Codecov Report

❌ Patch coverage is 98.67550% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 96.91%. Comparing base (023775e) to head (5c63203).
⚠️ Report is 1 commit behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/skillmodels/maximization_inputs.py | 60.00% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #87      +/-   ##
==========================================
+ Coverage   96.86%   96.91%   +0.05%     
==========================================
  Files          57       57              
  Lines        4809     4952     +143     
==========================================
+ Hits         4658     4799     +141     
- Misses        151      153       +2     


@janosg (Member) left a comment

Maybe we should compare the speed of just the update step to see if the linear one is implemented efficiently. Without a detailed analysis I would expect the linear update step to be at least twice as fast as the unscented one. Of course, in a model with few factors or many measurements per factor, the unscented predict might not have been the bottleneck anyway.

Comment on lines +298 to +309

```python
for i, factor in enumerate(latent_factors):
    if i in constant_factor_indices:
        row = jnp.zeros(n_all_factors).at[i].set(1.0)
        f_rows.append(row)
        c_vals.append(0.0)
    else:
        coeffs = trans_coeffs[factor]
        f_rows.append(coeffs[:-1])
        c_vals.append(coeffs[-1])

f_mat = jnp.stack(f_rows)  # (n_latent, n_all)
c_vec = jnp.array(c_vals)  # (n_latent,)
```
@janosg (Member)

Looks suboptimal but maybe Jax is smart enough at compiling the small array creation away. Have you tried different implementations?

@hmgaudecker (Member, Author)

Confirmed that Jax is smart enough. But I kept a more idiomatic version from the experiments and added a note.
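One way to verify claims like this is to lower a jitted version of the builder and inspect the compiled program; the Python-level list handling happens only at trace time. A sketch (hypothetical helper, not the skillmodels code):

```python
import jax
import jax.numpy as jnp


def build_f_and_c(trans_coeffs):
    # Same pattern as the snippet under discussion: stack per-factor
    # coefficient rows into a transition matrix and intercept vector.
    f_rows = [coeffs[:-1] for coeffs in trans_coeffs]
    c_vals = [coeffs[-1] for coeffs in trans_coeffs]
    return jnp.stack(f_rows), jnp.stack(c_vals)


coeffs = [jnp.array([1.0, 2.0, 3.0]), jnp.array([4.0, 5.0, 6.0])]
lowered = jax.jit(build_f_and_c).lower(coeffs)
print(lowered.as_text())  # Python lists are gone; only array ops remain
```

The lowered text shows what XLA actually receives, so list-comprehension overhead visible in the Python source does not appear there.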

@hmgaudecker (Member, Author)

hmgaudecker commented Mar 18, 2026

Re: the question about a linear update step —

The measurement model in skillmodels is always linear (states @ loadings + controls @ control_params), so kalman_update is already the exact linear update. The unscented transform only appears in the predict step (to propagate states through potentially nonlinear transition functions). There is no "unscented update" counterpart; the same kalman_update runs regardless of whether transitions are linear or nonlinear.

The QR decomposition in the update operates on an (n_states+1) × (n_states+1) matrix whose structure depends only on the measurement model (loadings, meas_sd), not on the transition model. So a linear_kalman_update would be identical to the current function — no optimization opportunity here.

(this is Claude, obviously, but it does appear plausible without checking deeply)
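For reference, the algebra that a scalar-measurement linear Kalman update computes, written in plain covariance form (names and signature are hypothetical; as noted above, the actual kalman_update uses a QR-based square-root form instead):

```python
import jax.numpy as jnp


def linear_update_sketch(x, p, y, loadings, meas_sd):
    """Covariance-form Kalman update for one scalar measurement.

    This is only the textbook algebra; the skillmodels implementation
    works on Cholesky factors via QR for numerical stability.
    """
    residual = y - loadings @ x
    f = p @ loadings                     # state/measurement covariance
    sigma2 = loadings @ f + meas_sd**2   # innovation variance (scalar)
    kalman_gain = f / sigma2
    x_new = x + kalman_gain * residual
    p_new = p - jnp.outer(kalman_gain, f)
    return x_new, p_new
```

Nothing in this algebra depends on whether the transition functions are linear, which is why there is no separate "linear update" to optimize.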

Base automatically changed from om.Constraints to main March 18, 2026 13:03
@hmgaudecker hmgaudecker merged commit 2a77ee4 into main Mar 18, 2026
6 checks passed
@hmgaudecker hmgaudecker deleted the linear-predict branch March 18, 2026 13:16
Development

Successfully merging this pull request may close these issues.

Use a linear predict when possible